# The code implementation refers to the following file from PyTorch:
# - https://github.com/pytorch/pytorch/blob/v1.13.0/torch/optim/optimizer.py
# Additional modifications were made by Huawei Technologies Co., Ltd in 2023.
# ============================================================================
"""optimizer"""
from __future__ import absolute_import
from collections import defaultdict
from collections.abc import Iterable

from mindspore import ops
from mindspore.nn.cell import Cell
from mindspore.common.parameter import Parameter, ParameterTuple
from mindspore.common import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.sparse_tensor import RowTensorInner
from mindspore import _checkparam as validator
from mindspore import log as logger


__all__ = ['Optimizer']


class Optimizer(Cell):
    r"""
    Base class for all optimizers.

    .. warning::
        This is an experimental optimizer API that is subject to change.
        This module must be used together with the learning rate scheduler modules in `LRScheduler Class
        <https://www.mindspore.cn/docs/en/master/api_python/mindspore.experimental.html#lrscheduler-class>`_ .

    Args:
        params (Union[list(Parameter), list(dict)]): an iterable of :class:`mindspore.Parameter` or
            dicts. Specifies which Parameters should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import nn, Tensor, Parameter
        >>> from mindspore import ops
        >>> from mindspore.experimental import optim
        >>>
        >>> class MySGD(optim.Optimizer):
        ...     def __init__(self, params, lr):
        ...         defaults = dict(lr=lr)
        ...         super(MySGD, self).__init__(params, defaults)
        ...
        ...     def construct(self, gradients):
        ...         # Plain gradient descent: param <- param - lr * grad.
        ...         for group_id, group in enumerate(self.param_groups):
        ...             start_id = self.group_start_id[group_id]
        ...             for i, param in enumerate(group["params"]):
        ...                 next_param = param - gradients[start_id + i] * group["lr"]
        ...                 ops.assign(param, next_param)
        >>>
        >>> net = nn.Dense(8, 2)
        >>> data = Tensor(np.random.rand(20, 8).astype(np.float32))
        >>> label = Tensor(np.random.rand(20, 2).astype(np.float32))
        >>>
        >>> optimizer = MySGD(net.trainable_params(), 0.01)
        >>> optimizer.add_param_group({"params": Parameter([0.01, 0.02])})
        >>>
        >>> criterion = nn.MAELoss(reduction="mean")
        >>>
        >>> def forward_fn(data, label):
        ...     logits = net(data)
        ...     loss = criterion(logits, label)
        ...     return loss, logits
        >>>
        >>> grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
        >>>
        >>> def train_step(data, label):
        ...     (loss, _), grads = grad_fn(data, label)
        ...     optimizer(grads)
        ...     print(loss)
        >>>
        >>> train_step(data, label)
    """
    def __init__(self, params, defaults):
        super(Optimizer, self).__init__(auto_prefix=False)

        param_groups = self._parameters_base_check(params, "params")
        self.defaults = defaults
        self.state = defaultdict(dict)
        self.param_groups = []
        self.parameters = []
        self.lrs = []
        self.map_ = ops.Map()
        self.group_start_id = [0]
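        # A plain list of Parameters (no dicts) is wrapped into a single default group.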
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)
        self.parameters = ParameterTuple(self.parameters)
        self.hyper_map = ops.HyperMap()
        self.enable_tuple_broaden = True

    def __repr__(self):
        format_string = self.__class__.__name__ + ' ('
        for i, group in enumerate(self.param_groups):
            format_string += '\n'
            format_string += 'Parameter Group {0}\n'.format(i)
            for key in sorted(group.keys()):
                if key != 'params':
                    value = group[key].value() if key == "lr" and isinstance(group[key], Parameter) else group[key]
                    format_string += '    {0}: {1}\n'.format(key, value)
        format_string += ')'
        return format_string

    def add_param_group(self, param_group):
        r"""
        Add a param group to the `Optimizer.param_groups`.

        Args:
            param_group (dict): Specifies what Parameters should be optimized along with group
                specific optimization options.
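
        Examples:
            >>> # A minimal sketch; assumes `net` and `optimizer` were created as in
            >>> # the class-level example above.
            >>> from mindspore import Parameter, Tensor
            >>> import mindspore
            >>> extra = Parameter(Tensor([0.0, 0.0], mindspore.float32), name="extra")
            >>> optimizer.add_param_group({"params": [extra], "lr": 0.1})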
        """
        group_id = len(self.param_groups)
        param_group = self._preprocess_param_group(param_group)
        self.parameters += tuple(param_group.get("params"))

        for name, default in self.defaults.items():
            if name not in param_group:
                param_group.setdefault(name, default)

        lr = self._build_single_lr(param_group.get("lr"), 'learning_rate_group_' + str(group_id))
        weight_decay = self._preprocess_weight_decay(param_group.get("weight_decay", 0.0))
        self.lrs.append(lr)
        param_group["lr"] = lr
        param_group["weight_decay"] = weight_decay
        if "amsgrad" in param_group and param_group.get("amsgrad") and hasattr(self, 'max_v_group'):
            param_items = ParameterTuple(tuple(param_group.get("params")))
            param_group["max_exp_avg_sq"] = param_items.clone(prefix="max_exp_avg_sq", init='zeros')
        self.param_groups.append(param_group)
        self.group_start_id.append(self.group_start_id[-1] + len(param_group.get("params")))

    @staticmethod
    def _parameters_base_check(parameters, param_info):
        """Parameters base check."""
        if parameters is None:
            raise ValueError(f"For 'Optimizer', the argument {param_info} can not be None.")
        if not isinstance(parameters, Iterable):
            raise TypeError(f"For 'Optimizer', the argument {param_info} must be Iterable type, "
                            f"but got {type(parameters)}.")
        parameters = list(parameters)

        if not parameters:
            raise ValueError(f"For 'Optimizer', the argument {param_info} must not be empty.")
        return parameters

    def _decay_weight(self, weight_decay, params, gradients):
        """Fold weight decay into the gradients: grad <- grad + weight_decay * weight."""
        if weight_decay != 0.:
            weight_decay = Tensor(weight_decay, mstype.float32)
            gradients = self.map_(ops.partial(_apply_decay, weight_decay), params, gradients)
        return gradients

    def _preprocess_param_group(self, param_group):
        """Preprocess param groups."""
        if not isinstance(param_group, dict):
            raise TypeError('Param group must be a dict.')

        params = param_group['params']
        if isinstance(params, Parameter):
            param_group['params'] = [params]
        elif isinstance(params, set):
            raise TypeError('Optimizer parameters need to be organized in ordered collections, but '
                            'the ordering of Parameters in sets will change between runs. '
                            'Please use a list instead.')
        else:
            param_group['params'] = list(params)

        for param in param_group['params']:
            if not isinstance(param, Parameter):
                raise TypeError(f"Optimizer can only optimize Parameters, but one of the params is {type(param)}.")

        if len(param_group['params']) != len(set(param_group['params'])):
            logger.warning("Optimizer contains a parameter group with duplicate parameters.")

        param_set = set()
        for group in self.param_groups:
            param_set.update(set(group['params']))
        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError("some parameters appear in more than one parameter group.")
        return param_group

    def _build_single_lr(self, learning_rate, name):
        """Validate the lr value and wrap it in a scalar float32 Parameter."""
        if isinstance(learning_rate, (float, int)):
            learning_rate = float(learning_rate)
            validator.check_non_negative_float(learning_rate, "learning rate", self.cls_name)
            return Parameter(Tensor(learning_rate, mstype.float32), name)

        if isinstance(learning_rate, Tensor):
            if learning_rate.ndim == 0:
                return Parameter(learning_rate.astype(mstype.float32), name)
            raise ValueError("For 'Optimizer', if 'learning_rate' is a Tensor, "
                             "it must be a scalar Tensor.")

        raise TypeError("For 'Optimizer', the argument 'learning_rate' must be int, float or Tensor, "
                        "but got {}.".format(type(learning_rate)))

    def _preprocess_weight_decay(self, weight_decay):
        """Validate weight decay and convert it to a non-negative float."""
        if isinstance(weight_decay, (float, int)):
            weight_decay = float(weight_decay)
            validator.check_non_negative_float(weight_decay, "weight_decay", self.cls_name)
        else:
            raise TypeError("For 'Optimizer', the argument 'weight_decay' must be int or "
                            "float, but got {}.".format(type(weight_decay)))
        return weight_decay

    def construct(self, *hyper_params):
        """Apply one optimization step; must be overridden by subclasses."""
        raise NotImplementedError

op_add = ops.AddN()
op_gather = ops.Gather()
op_mul = ops.Mul()

_apply_decay = ops.MultitypeFuncGraph("apply_decay")


@_apply_decay.register("Tensor", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, weight, gradient):
    """Get grad with weight_decay."""
    indices = gradient.indices
    values = op_add((op_gather(weight, indices, 0) * ops.cast(weight_decay, ops.dtype(weight)), gradient.values))
    shape = gradient.dense_shape
    return RowTensorInner(indices, values, shape)


@_apply_decay.register("Tensor", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, weight, gradient):
    """Apply weight decay to a dense gradient: gradient + weight_decay * weight."""
    return op_add((op_mul(weight, ops.cast(weight_decay, ops.dtype(weight))), gradient))


def check_not_less_than(arg_value, arg_name, prim, value=0.0):
    """Validate that `arg_value` is greater than or equal to `value`."""
    if arg_value < value:
        raise ValueError("For {}, the {} must be greater than or equal to {}, "
                         "but got {}.".format(prim, arg_name, value, arg_value))


def check_not_less_than_without_equal(arg_value, arg_name, prim, value=0.0):
    """Validate that `arg_value` is strictly greater than `value`."""
    if arg_value <= value:
        raise ValueError("For {}, the {} must be greater than {}, "
                         "but got {}.".format(prim, arg_name, value, arg_value))
